import copy
from pathlib import Path
from typing import Dict, Optional, Union

import torch
from accelerate.utils import extract_model_from_parallel
from safetensors.torch import save_file
from torch import nn
from torch.utils import data

from kronfluence.arguments import FactorArguments
from kronfluence.computer.factor_computer import FactorComputer
from kronfluence.computer.score_computer import ScoreComputer
from kronfluence.module.utils import wrap_tracked_modules
from kronfluence.task import Task
from kronfluence.utils.dataset import DataLoaderKwargs
from kronfluence.utils.save import load_file, verify_models_equivalence


def prepare_model(
    model: nn.Module,
    task: Task,
) -> nn.Module:
    """Prepares a model for influence analysis.

    The model is modified in place: it is switched to evaluation mode and every
    parameter and buffer has `requires_grad` turned off. Supported submodules are
    then wrapped with `TrackedModule` via `wrap_tracked_modules`.

    Args:
        model: The model to prepare.
        task: The `Task` forwarded to `wrap_tracked_modules`.

    Returns:
        The model returned by `wrap_tracked_modules`.
    """
    model.eval()

    # Freeze all tensors owned by the model (parameters first, then buffers).
    frozen_tensors = (*model.parameters(), *model.buffers())
    for tensor in frozen_tensors:
        tensor.requires_grad = False

    # Install `TrackedModule` wrappers on supported modules.
    wrapped = wrap_tracked_modules(model=model, task=task)
    return wrapped


class Analyzer(FactorComputer, ScoreComputer):
    """Single entry point for fitting influence factors and computing influence scores.

    Combines `FactorComputer` and `ScoreComputer`. It can optionally persist the
    model's `state_dict` so that later runs under the same analysis are checked
    against the same model.
    """

    def __init__(
        self,
        analysis_name: str,
        model: nn.Module,
        task: Task,
        cpu: bool = False,
        log_level: Optional[int] = None,
        log_main_process_only: bool = True,
        profile: bool = False,
        disable_tqdm: bool = False,
        output_dir: str = "./influence_results",
        disable_model_save: bool = True,
    ) -> None:
        """Initializes the analyzer.

        Args:
            analysis_name: Name of the analysis, forwarded to the base computer as `name`.
            model: The model to analyze.
            task: The `Task` forwarded to the base computer.
            cpu: Forwarded unchanged to the base computer.
            log_level: Forwarded unchanged to the base computer.
            log_main_process_only: Forwarded unchanged to the base computer.
            profile: Forwarded unchanged to the base computer.
            disable_tqdm: Forwarded unchanged to the base computer.
            output_dir: Root output directory, forwarded to the base computer.
            disable_model_save: When False, the main process either saves the model's
                `state_dict` to `model.safetensors` under `self.output_dir`, or checks
                it against an existing file there.

        Raises:
            ValueError: If model saving is enabled and an existing `model.safetensors`
                differs from `model`.
        """
        super().__init__(
            name=analysis_name,
            model=model,
            task=task,
            cpu=cpu,
            log_level=log_level,
            log_main_process_only=log_main_process_only,
            profile=profile,
            disable_tqdm=disable_tqdm,
            output_dir=output_dir,
        )
        self.logger.info(f"Initializing `Analyzer` with parameters: {locals()}")
        self.logger.info(f"Process state configuration:\n{repr(self.state)}")

        # Save model parameters if necessary. Only the main process touches the file;
        # all other processes wait here so none of them proceeds before it is written.
        if self.state.is_main_process and not disable_model_save:
            self._save_model()
        self.state.wait_for_everyone()

    def set_dataloader_kwargs(self, dataloader_kwargs: DataLoaderKwargs) -> None:
        """Stores DataLoader keyword arguments on the analyzer.

        Args:
            dataloader_kwargs: Keyword arguments stored as `self._dataloader_params`.
                NOTE(review): presumably used by the base computers when building
                DataLoaders; confirm there.
        """
        self._dataloader_params = dataloader_kwargs

    @torch.no_grad()
    def _save_model(self) -> None:
        """Saves the model's `state_dict`, or verifies it against an already saved one.

        The file is `model.safetensors` in `self.output_dir`. If the file exists, its
        contents are compared with the current model. Otherwise the current
        `state_dict` is written there.

        Raises:
            ValueError: If an existing saved model differs from the current model.
        """
        model_save_path = self.output_dir / "model.safetensors"
        # Unwrap any parallel wrapper from a deep copy, presumably so `self.model`
        # itself is never modified by the unwrapping.
        extracted_model = extract_model_from_parallel(model=copy.deepcopy(self.model), keep_fp32_wrapper=True)

        if model_save_path.exists():
            self.logger.info(f"Found existing saved model at `{model_save_path}`.")
            # Load existing model's `state_dict` for comparison.
            loaded_state_dict = load_file(model_save_path)
            if not verify_models_equivalence(loaded_state_dict, extracted_model.state_dict()):
                error_msg = (
                    "Detected a difference between the current model and the one saved at "
                    f"`{model_save_path}`. Consider using a different `analysis_name` to avoid conflicts."
                )
                self.logger.error(error_msg)
                raise ValueError(error_msg)
        else:
            self.logger.info(f"No existing model found at `{model_save_path}`.")
            state_dict = extracted_model.state_dict()
            # safetensors rejects non-contiguous tensors and tensors that share storage.
            # Cloning gives every entry its own contiguous storage.
            state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
            save_file(state_dict, model_save_path)
            self.logger.info(f"Saved model at `{model_save_path}`.")

    def fit_all_factors(
        self,
        factors_name: str,
        dataset: data.Dataset,
        per_device_batch_size: Optional[int] = None,
        initial_per_device_batch_size_attempt: int = 4096,
        dataloader_kwargs: Optional[DataLoaderKwargs] = None,
        factor_args: Optional[FactorArguments] = None,
        overwrite_output_dir: bool = False,
    ) -> None:
        """Runs the full factor pipeline under one `factors_name`.

        The steps run in order:
        1. `fit_covariance_matrices`
        2. `perform_eigendecomposition`
        3. `fit_lambda_matrices`

        Args:
            factors_name: Name shared by all three steps.
            dataset: Dataset used by the covariance and lambda steps.
            per_device_batch_size: Forwarded to the covariance and lambda steps.
            initial_per_device_batch_size_attempt: Forwarded to the covariance and
                lambda steps.
            dataloader_kwargs: Forwarded to the covariance and lambda steps.
            factor_args: Forwarded to all three steps.
            overwrite_output_dir: Forwarded to all three steps.
        """
        self.fit_covariance_matrices(
            factors_name=factors_name,
            dataset=dataset,
            per_device_batch_size=per_device_batch_size,
            initial_per_device_batch_size_attempt=initial_per_device_batch_size_attempt,
            dataloader_kwargs=dataloader_kwargs,
            factor_args=factor_args,
            overwrite_output_dir=overwrite_output_dir,
        )
        self.perform_eigendecomposition(
            factors_name=factors_name,
            factor_args=factor_args,
            overwrite_output_dir=overwrite_output_dir,
        )
        self.fit_lambda_matrices(
            factors_name=factors_name,
            dataset=dataset,
            per_device_batch_size=per_device_batch_size,
            initial_per_device_batch_size_attempt=initial_per_device_batch_size_attempt,
            dataloader_kwargs=dataloader_kwargs,
            factor_args=factor_args,
            overwrite_output_dir=overwrite_output_dir,
        )

    @staticmethod
    def load_file(path: Union[str, Path]) -> Dict[str, torch.Tensor]:
        """Loads a saved tensor file (e.g. `model.safetensors`) into a dictionary of tensors.

        Args:
            path: Path to the file, as a string or any path-like object.

        Returns:
            Mapping from tensor name to tensor.

        Raises:
            FileNotFoundError: If `path` does not exist.
        """
        # Normalize every input (str, Path, or other os.PathLike) to an absolute Path.
        path = Path(path).resolve()
        if not path.exists():
            raise FileNotFoundError(f"File not found: {path}.")
        # This name resolves to the module-level `load_file` imported from
        # `kronfluence.utils.save`. Class attributes are not in scope inside method
        # bodies, so it does not recurse into this static method.
        return load_file(path)

    @staticmethod
    def get_module_summary(model: nn.Module) -> str:
        """Builds a text summary of the model's leaf modules that own parameters.

        Args:
            model: The model to summarize.

        Returns:
            A string starting with `==Model Summary==`, followed by one line per leaf
            module (no child modules) that has at least one parameter.
        """
        lines = ["==Model Summary=="]
        for module_name, module in model.named_modules():
            # Skip container modules; only leaves are reported.
            if next(module.children(), None) is not None:
                continue
            # Skip parameter-free leaves (e.g. activations).
            if next(module.parameters(), None) is None:
                continue
            lines.append(f"Module Name: `{module_name}`, Module: {repr(module)}")
        # Joining once avoids the quadratic cost of repeated string concatenation.
        return "\n".join(lines)
